Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Suggestions for pointwise_logdensities and siblings #669

Closed
wants to merge 14 commits into from

Conversation

torfjelde
Copy link
Member

@torfjelde torfjelde commented Sep 23, 2024

This is based on #663 but with the following changes:

  • Un-deprecated the pointwise_loglikelihoods (but it just uses pointwise_logdensities under the hood)
  • Added pointwise_prior_logdensities.
  • Unrolls the dot_tilde_assume so we can handle .~ correctly.

@bgctw I wanted to make a PR to your PR but doesn't seem possible due to being on your fork. I recommend looking at the diffs I've added. We can discuss the changes here and then you can incorporate those we want in your PR so that you get proper credit 👍

EDIT: Ignore all the formatting stuff. I'm deliberately not doing that to make the diff between this branch and @bgctw 's branch simpler to read.

Similarly, one can compute the pointwise log-priors of each sampled random variable
with [`varwise_logpriors`](@ref).
Differently from `pointwise_loglikelihoods` it reports only a
single value for `.~` assignements.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
single value for `.~` assignements.
single value for `.~` assignements.

function TestLogModifyingChildContext(
mod=1.2,
context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext(),
#OrderedDict{VarName,Vector{Float64}}(),PriorContext()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
#OrderedDict{VarName,Vector{Float64}}(),PriorContext()),
#OrderedDict{VarName,Vector{Float64}}(),PriorContext()),

Comment on lines +1058 to +1060
return TestLogModifyingChildContext{typeof(mod),typeof(context)}(
mod, context
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return TestLogModifyingChildContext{typeof(mod),typeof(context)}(
mod, context
)
return TestLogModifyingChildContext{typeof(mod),typeof(context)}(mod, context)

function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi)
#@info "TestLogModifyingChildContext tilde_assume!! called for $vn"
value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi)
return value, logp*context.mod, vi
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return value, logp*context.mod, vi
return value, logp * context.mod, vi

value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi)
return value, logp*context.mod, vi
end
function DynamicPPL.dot_tilde_assume(context::TestLogModifyingChildContext, right, left, vn, vi)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function DynamicPPL.dot_tilde_assume(context::TestLogModifyingChildContext, right, left, vn, vi)
function DynamicPPL.dot_tilde_assume(
context::TestLogModifyingChildContext, right, left, vn, vi
)

Comment on lines +77 to +78
arr0 = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 0.9199555480151707 -0.1304320097505629 1.0669120062696917 -0.05253734412139093; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 2.5409470583244933 1.7838744695696407 0.7013562890105632 -3.0843947804314658; 0.8296370582311665 1.5360702767879642 -1.5964695255693102 0.16928084806166913; 2.6246697053824954 0.8096845024785173 -1.2621822861663752 1.1414885535466166; 1.1304261861894538 0.7325784741344005 -1.1866016911837542 -0.1639319562090826; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 0.9838526141898173 -0.20198797220982412 2.0569535882007006 -1.1560724118010939]
chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
arr0 = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 0.9199555480151707 -0.1304320097505629 1.0669120062696917 -0.05253734412139093; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 2.5409470583244933 1.7838744695696407 0.7013562890105632 -3.0843947804314658; 0.8296370582311665 1.5360702767879642 -1.5964695255693102 0.16928084806166913; 2.6246697053824954 0.8096845024785173 -1.2621822861663752 1.1414885535466166; 1.1304261861894538 0.7325784741344005 -1.1866016911837542 -0.1639319562090826; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 0.9838526141898173 -0.20198797220982412 2.0569535882007006 -1.1560724118010939]
chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]);
arr0 = [
5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 0.9199555480151707 -0.1304320097505629 1.0669120062696917 -0.05253734412139093; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183;;;
3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 2.5409470583244933 1.7838744695696407 0.7013562890105632 -3.0843947804314658; 0.8296370582311665 1.5360702767879642 -1.5964695255693102 0.16928084806166913; 2.6246697053824954 0.8096845024785173 -1.2621822861663752 1.1414885535466166; 1.1304261861894538 0.7325784741344005 -1.1866016911837542 -0.1639319562090826; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 0.9838526141898173 -0.20198797220982412 2.0569535882007006 -1.1560724118010939
]
chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")])

chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]);
tmp1 = pointwise_logdensities(model, chain)
vi = VarInfo(model)
i_sample, i_chain = (1,2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
i_sample, i_chain = (1,2)
i_sample, i_chain = (1, 2)

vi = VarInfo(model)
i_sample, i_chain = (1,2)
DynamicPPL.setval!(vi, chain, i_sample, i_chain)
lp1 = DynamicPPL.pointwise_logdensities(model, vi)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
lp1 = DynamicPPL.pointwise_logdensities(model, vi)
lp1 = DynamicPPL.pointwise_logdensities(model, vi)

lp1 = DynamicPPL.pointwise_logdensities(model, vi)
# k = first(keys(lp1))
for k in keys(lp1)
@test tmp1[string(k)][i_sample,i_chain] .≈ lp1[k][1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@test tmp1[string(k)][i_sample,i_chain] .≈ lp1[k][1]
@test tmp1[string(k)][i_sample, i_chain] .≈ lp1[k][1]

for k in keys(lp1)
@test tmp1[string(k)][i_sample,i_chain] .≈ lp1[k][1]
end
end;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end;
end;

@torfjelde torfjelde mentioned this pull request Sep 23, 2024
bgctw added a commit to bgctw/DynamicPPL.jl that referenced this pull request Sep 24, 2024
github-merge-queue bot pushed a commit that referenced this pull request Sep 30, 2024
* implement pointwise_logpriors

* implement varwise_logpriors

* remove pointwise_logpriors

* revert dot_assume to not explicitly resolve components of sum

* docstring varwise_logpriores

use loop for prior in example

Unfortunately cannot make it a jldoctest, because relies on Turing for sampling

* integrate pointwise_loglikelihoods and varwise_logpriors by pointwise_densities

* record single prior components

by forwarding dot_tilde_assume to tilde_assume

* forward dot_tilde_assume to tilde_assume   for Multivariate

* avoid recording prior components on leaf-prior-context

and avoid recording likelihoods when invoked with leaf-Likelihood context

* undeprecate pointwise_loglikelihoods and implement pointwise_prior_logdensities

mostly taken from #669

* drop vi instead of re-compute vi

bgctw first forwared dot_tilde_assume to get a correct vi
and then recomputed it for recording component prior densities.

Replaced this by the Hack of torfjelde that completely drops vi and recombines the value, so that assume is called only once for each varName,

* include docstrings of pointwise_logdensities
pointwise_prior_logdensities
int api.md docu

* Update src/pointwise_logdensities.jl remove commented code

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* Update src/pointwise_logdensities.jl remove commented code

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* Update test/pointwise_logdensities.jl rename m to model

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* Update test/pointwise_logdensities.jl remove unused code

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* Update test/pointwise_logdensities.jl rename m to model

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* Update test/pointwise_logdensities.jl rename m to model

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* Update src/test_utils.jl remove old code

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* rename m to model

* JuliaFormatter

* Update test/runtests.jl remove interactive code

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* remove demo_dot_assume_matrix_dot_observe_matrix2 testcase

testing  higher dimensions better left for other  PR

* ignore local interactive development code

* ignore temporary directory holding local interactive development code

* Apply suggestions from code review: clean up comments and Imports

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* Apply suggestions from code review: change test of applying to chains on already used model

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* fix test on names in likelihood components

to work with literal models

* try to fix testset pointwise_logdensities chain

* Update test/pointwise_logdensities.jl

* Update .gitignore

* Formtating

* Fixed tests

* Updated docs for `pointwise_logdensities` + made it a doctest not
dependent on Turing.jl

* Bump patch version

* Remove blank line from `@model` in doctest to see if that fixes the
parsing issues

* Added doctest filter to handle the `;;]` at the end of lines for matrices

---------

Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: Tor Erlend Fjelde <[email protected]>
@torfjelde torfjelde closed this Nov 28, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants